{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n",
      "/data/users/fyx/.local/python3/lib/python3.6/importlib/_bootstrap.py:219: RuntimeWarning: numpy.dtype size changed, may indicate binary incompatibility. Expected 96, got 88\n",
      "  return f(*args, **kwds)\n",
      "/data/users/fyx/.local/python3/lib/python3.6/importlib/_bootstrap.py:219: RuntimeWarning: numpy.dtype size changed, may indicate binary incompatibility. Expected 96, got 88\n",
      "  return f(*args, **kwds)\n"
     ]
    }
   ],
   "source": [
    "import keras\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matchzoo as mz"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# WikiQA ranking splits: the dev and test loads pass filter=True, the train load does not.\n",
    "train_pack = mz.datasets.wiki_qa.load_data('train', task='ranking')\n",
    "valid_pack, predict_pack = (\n",
    "    mz.datasets.wiki_qa.load_data(stage, task='ranking', filter=True)\n",
    "    for stage in ('dev', 'test')\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing text_left with chain_transform of TokenizeUnit => LowercaseUnit => PuncRemovalUnit: 100%|██████████| 2118/2118 [00:00<00:00, 8579.17it/s]\n",
      "Processing text_right with chain_transform of TokenizeUnit => LowercaseUnit => PuncRemovalUnit: 100%|██████████| 18841/18841 [00:03<00:00, 4780.26it/s]\n",
      "Processing text_right with append: 100%|██████████| 18841/18841 [00:00<00:00, 793024.40it/s]\n",
      "Building FrequencyFilterUnit from a datapack.: 100%|██████████| 18841/18841 [00:00<00:00, 130690.93it/s]\n",
      "Processing text_right with transform: 100%|██████████| 18841/18841 [00:00<00:00, 71219.51it/s] \n",
      "Processing text_left with extend: 100%|██████████| 2118/2118 [00:00<00:00, 647535.23it/s]\n",
      "Processing text_right with extend: 100%|██████████| 18841/18841 [00:00<00:00, 681901.49it/s]\n",
      "Building VocabularyUnit from a datapack.: 100%|██████████| 404415/404415 [00:00<00:00, 2776105.71it/s]\n",
      "Processing text_left with chain_transform of TokenizeUnit => LowercaseUnit => PuncRemovalUnit: 100%|██████████| 2118/2118 [00:00<00:00, 8584.20it/s]\n",
      "Processing text_right with chain_transform of TokenizeUnit => LowercaseUnit => PuncRemovalUnit: 100%|██████████| 18841/18841 [00:03<00:00, 4774.88it/s]\n",
      "Processing text_right with transform: 100%|██████████| 18841/18841 [00:00<00:00, 119072.37it/s]\n",
      "Processing text_left with transform: 100%|██████████| 2118/2118 [00:00<00:00, 94050.46it/s]\n",
      "Processing text_right with transform: 100%|██████████| 18841/18841 [00:00<00:00, 114873.94it/s]\n",
      "Processing length_left with len: 100%|██████████| 2118/2118 [00:00<00:00, 562213.52it/s]\n",
      "Processing length_right with len: 100%|██████████| 18841/18841 [00:00<00:00, 701657.54it/s]\n",
      "Processing text_left with transform: 100%|██████████| 2118/2118 [00:00<00:00, 112934.44it/s]\n",
      "Processing text_right with transform: 100%|██████████| 18841/18841 [00:00<00:00, 91195.25it/s]\n"
     ]
    }
   ],
   "source": [
    "preprocessor = mz.preprocessors.BasicPreprocessor(fixed_length_left=10, fixed_length_right=40, remove_stop_words=False)\n",
    "train_pack_processed = preprocessor.fit_transform(train_pack)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing text_left with chain_transform of TokenizeUnit => LowercaseUnit => PuncRemovalUnit: 100%|██████████| 122/122 [00:00<00:00, 8260.10it/s]\n",
      "Processing text_right with chain_transform of TokenizeUnit => LowercaseUnit => PuncRemovalUnit: 100%|██████████| 1115/1115 [00:00<00:00, 4878.05it/s]\n",
      "Processing text_right with transform: 100%|██████████| 1115/1115 [00:00<00:00, 106546.58it/s]\n",
      "Processing text_left with transform: 100%|██████████| 122/122 [00:00<00:00, 105159.29it/s]\n",
      "Processing text_right with transform: 100%|██████████| 1115/1115 [00:00<00:00, 118264.44it/s]\n",
      "Processing length_left with len: 100%|██████████| 122/122 [00:00<00:00, 142774.86it/s]\n",
      "Processing length_right with len: 100%|██████████| 1115/1115 [00:00<00:00, 551556.66it/s]\n",
      "Processing text_left with transform: 100%|██████████| 122/122 [00:00<00:00, 64568.47it/s]\n",
      "Processing text_right with transform: 100%|██████████| 1115/1115 [00:00<00:00, 74741.48it/s]\n",
      "Processing text_left with chain_transform of TokenizeUnit => LowercaseUnit => PuncRemovalUnit: 100%|██████████| 237/237 [00:00<00:00, 9089.12it/s]\n",
      "Processing text_right with chain_transform of TokenizeUnit => LowercaseUnit => PuncRemovalUnit: 100%|██████████| 2300/2300 [00:00<00:00, 4791.16it/s]\n",
      "Processing text_right with transform: 100%|██████████| 2300/2300 [00:00<00:00, 130961.68it/s]\n",
      "Processing text_left with transform: 100%|██████████| 237/237 [00:00<00:00, 159892.24it/s]\n",
      "Processing text_right with transform: 100%|██████████| 2300/2300 [00:00<00:00, 131134.36it/s]\n",
      "Processing length_left with len: 100%|██████████| 237/237 [00:00<00:00, 252938.94it/s]\n",
      "Processing length_right with len: 100%|██████████| 2300/2300 [00:00<00:00, 602028.16it/s]\n",
      "Processing text_left with transform: 100%|██████████| 237/237 [00:00<00:00, 91247.48it/s]\n",
      "Processing text_right with transform: 100%|██████████| 2300/2300 [00:00<00:00, 89652.70it/s]\n"
     ]
    }
   ],
   "source": [
    "valid_pack_processed = preprocessor.transform(valid_pack)\n",
    "predict_pack_processed = preprocessor.transform(predict_pack)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ranking task trained with a rank hinge loss; reported metrics are NDCG@3, NDCG@5 and MAP.\n",
    "ranking_task = mz.tasks.Ranking(loss=mz.losses.RankHingeLoss())\n",
    "ranking_task.metrics = [\n",
    "    *(mz.metrics.NormalizedDiscountedCumulativeGain(k=cutoff) for cutoff in (3, 5)),\n",
    "    mz.metrics.MeanAveragePrecision(),\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Parameter \"name\" set to KNRM.\n"
     ]
    }
   ],
   "source": [
    "model = mz.models.KNRM()\n",
    "model.params['input_shapes'] = preprocessor.context['input_shapes']\n",
    "model.params['task'] = ranking_task\n",
    "model.params['embedding_input_dim'] = preprocessor.context['vocab_size']\n",
    "model.params['embedding_output_dim'] = 300\n",
    "model.params['embedding_trainable'] = True\n",
    "model.params['kernel_num'] = 21\n",
    "model.params['sigma'] = 0.1\n",
    "model.params['exact_sigma'] = 0.001\n",
    "model.params['optimizer'] = 'adadelta'\n",
    "model.guess_and_fill_missing_params()\n",
    "model.build()\n",
    "model.compile()\n",
    "#model.backend.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load 300-dimensional GloVe vectors and build a matrix from the fitted vocabulary's term index.\n",
    "glove_embedding = mz.datasets.embeddings.load_glove_embedding(dimension=300)\n",
    "term_index = preprocessor.context['vocab_unit'].state['term_index']\n",
    "embedding_matrix = glove_embedding.build_matrix(term_index)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hand the prepared embedding matrix to the compiled model.\n",
    "model.load_embedding_matrix(embedding_matrix)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Unpack the processed test pack into model inputs and labels for the per-epoch callback.\n",
    "pred_x, pred_y = predict_pack_processed[:].unpack()\n",
    "# Size the evaluation batch by the number of labels: unpack() returns x as a dict of\n",
    "# feature arrays, so len(pred_x) would count its keys rather than the examples.\n",
    "evaluate = mz.callbacks.EvaluateAllMetrics(model, x=pred_x, y=pred_y, batch_size=len(pred_y))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "102"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_generator = mz.PairDataGenerator(train_pack_processed, num_dup=2, num_neg=1, batch_size=20)\n",
    "len(train_generator)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/30\n",
      "102/102 [==============================] - 10s 100ms/step - loss: 1.0605\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.536247360913153 - normalized_discounted_cumulative_gain@5(0): 0.5963242726387614 - mean_average_precision(0): 0.5575601476305173\n",
      "Epoch 2/30\n",
      "102/102 [==============================] - 8s 76ms/step - loss: 0.7327\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5721792212823565 - normalized_discounted_cumulative_gain@5(0): 0.6183406062081045 - mean_average_precision(0): 0.5900931353296444\n",
      "Epoch 3/30\n",
      "102/102 [==============================] - 8s 76ms/step - loss: 0.5949\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5745879488231151 - normalized_discounted_cumulative_gain@5(0): 0.6351119965492307 - mean_average_precision(0): 0.5858594498483102\n",
      "Epoch 4/30\n",
      "102/102 [==============================] - 8s 75ms/step - loss: 0.5018\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.583688238595943 - normalized_discounted_cumulative_gain@5(0): 0.641777392154616 - mean_average_precision(0): 0.5910642598671249\n",
      "Epoch 5/30\n",
      "102/102 [==============================] - 8s 75ms/step - loss: 0.4301\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.6023722228591598 - normalized_discounted_cumulative_gain@5(0): 0.6523861382312431 - mean_average_precision(0): 0.6044957790226468\n",
      "Epoch 6/30\n",
      "102/102 [==============================] - 8s 74ms/step - loss: 0.3690\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5724280981434529 - normalized_discounted_cumulative_gain@5(0): 0.6301918923499736 - mean_average_precision(0): 0.5800924897922569\n",
      "Epoch 7/30\n",
      "102/102 [==============================] - 7s 73ms/step - loss: 0.3161\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5615408440892258 - normalized_discounted_cumulative_gain@5(0): 0.6194237730365307 - mean_average_precision(0): 0.5690610623176658\n",
      "Epoch 8/30\n",
      "102/102 [==============================] - 7s 73ms/step - loss: 0.2658\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.563697464868493 - normalized_discounted_cumulative_gain@5(0): 0.613101520634656 - mean_average_precision(0): 0.56312181950864\n",
      "Epoch 9/30\n",
      "102/102 [==============================] - 8s 74ms/step - loss: 0.2219\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5537013878793432 - normalized_discounted_cumulative_gain@5(0): 0.6221510297300898 - mean_average_precision(0): 0.5647489839522603\n",
      "Epoch 10/30\n",
      "102/102 [==============================] - 7s 73ms/step - loss: 0.1834\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5509698705941599 - normalized_discounted_cumulative_gain@5(0): 0.623629245078167 - mean_average_precision(0): 0.5637408708975853\n",
      "Epoch 11/30\n",
      "102/102 [==============================] - 8s 74ms/step - loss: 0.1494\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5564963942672853 - normalized_discounted_cumulative_gain@5(0): 0.6204029339135546 - mean_average_precision(0): 0.569557844439296\n",
      "Epoch 12/30\n",
      "102/102 [==============================] - 7s 73ms/step - loss: 0.1216\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.543734856289735 - normalized_discounted_cumulative_gain@5(0): 0.61158313069304 - mean_average_precision(0): 0.5550216251165617\n",
      "Epoch 13/30\n",
      "102/102 [==============================] - 7s 73ms/step - loss: 0.0987\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5441235707124268 - normalized_discounted_cumulative_gain@5(0): 0.6097137332279133 - mean_average_precision(0): 0.5545668576493274\n",
      "Epoch 14/30\n",
      "102/102 [==============================] - 8s 74ms/step - loss: 0.0767\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5371964746837083 - normalized_discounted_cumulative_gain@5(0): 0.605243446771984 - mean_average_precision(0): 0.5511582387373527\n",
      "Epoch 15/30\n",
      "102/102 [==============================] - 7s 73ms/step - loss: 0.0607\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5280359257085162 - normalized_discounted_cumulative_gain@5(0): 0.6017280544990676 - mean_average_precision(0): 0.5444697239001036\n",
      "Epoch 16/30\n",
      "102/102 [==============================] - 7s 72ms/step - loss: 0.0467\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5148355272699622 - normalized_discounted_cumulative_gain@5(0): 0.5922520479170785 - mean_average_precision(0): 0.5329816493374122\n",
      "Epoch 17/30\n",
      "102/102 [==============================] - 7s 73ms/step - loss: 0.0338\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5284690219246697 - normalized_discounted_cumulative_gain@5(0): 0.6004256047682324 - mean_average_precision(0): 0.5428607755348261\n",
      "Epoch 18/30\n",
      "102/102 [==============================] - 7s 73ms/step - loss: 0.0265\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5205780888100985 - normalized_discounted_cumulative_gain@5(0): 0.5938831572013079 - mean_average_precision(0): 0.536485919415158\n",
      "Epoch 19/30\n",
      "102/102 [==============================] - 8s 74ms/step - loss: 0.0201\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5214886487261916 - normalized_discounted_cumulative_gain@5(0): 0.5926060509240859 - mean_average_precision(0): 0.5337510027570821\n",
      "Epoch 20/30\n",
      "102/102 [==============================] - 8s 74ms/step - loss: 0.0149\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5268129504440942 - normalized_discounted_cumulative_gain@5(0): 0.5985675662210089 - mean_average_precision(0): 0.5435044181088485\n",
      "Epoch 21/30\n",
      "102/102 [==============================] - 7s 73ms/step - loss: 0.0111\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5195324002526024 - normalized_discounted_cumulative_gain@5(0): 0.5928370213385415 - mean_average_precision(0): 0.5337503536079485\n",
      "Epoch 22/30\n",
      "102/102 [==============================] - 7s 73ms/step - loss: 0.0080\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.514104739020777 - normalized_discounted_cumulative_gain@5(0): 0.5865870386477195 - mean_average_precision(0): 0.5302430099294689\n",
      "Epoch 23/30\n",
      "102/102 [==============================] - 7s 73ms/step - loss: 0.0059\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5231459873789948 - normalized_discounted_cumulative_gain@5(0): 0.5891100751736064 - mean_average_precision(0): 0.533002369650854\n",
      "Epoch 24/30\n",
      "102/102 [==============================] - 7s 71ms/step - loss: 0.0048\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5198961794269974 - normalized_discounted_cumulative_gain@5(0): 0.5881604208203274 - mean_average_precision(0): 0.5320192005110392\n",
      "Epoch 25/30\n",
      "102/102 [==============================] - 7s 73ms/step - loss: 0.0030\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5193437332093963 - normalized_discounted_cumulative_gain@5(0): 0.5858565495360579 - mean_average_precision(0): 0.5311297594317248\n",
      "Epoch 26/30\n",
      "102/102 [==============================] - 8s 75ms/step - loss: 0.0020\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5220058840683476 - normalized_discounted_cumulative_gain@5(0): 0.5901567482051235 - mean_average_precision(0): 0.5346151394993833\n",
      "Epoch 27/30\n",
      "102/102 [==============================] - 7s 72ms/step - loss: 0.0016\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.513567065502947 - normalized_discounted_cumulative_gain@5(0): 0.5845364070836736 - mean_average_precision(0): 0.528327057975861\n",
      "Epoch 28/30\n",
      "102/102 [==============================] - 8s 82ms/step - loss: 6.7244e-04\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5214386651509757 - normalized_discounted_cumulative_gain@5(0): 0.5899229427918918 - mean_average_precision(0): 0.5351685990716492\n",
      "Epoch 29/30\n",
      "102/102 [==============================] - 7s 72ms/step - loss: 3.4626e-04\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5198961794269974 - normalized_discounted_cumulative_gain@5(0): 0.5926765216445898 - mean_average_precision(0): 0.5349571008173606\n",
      "Epoch 30/30\n",
      "102/102 [==============================] - 8s 75ms/step - loss: 1.4500e-04\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.514999308449345 - normalized_discounted_cumulative_gain@5(0): 0.5795815468524761 - mean_average_precision(0): 0.5233486618958805\n"
     ]
    }
   ],
   "source": [
    "history = model.fit_generator(train_generator, epochs=30, callbacks=[evaluate], workers=30, use_multiprocessing=True)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
